import tqdm
from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate
from utils.metrics import metric
import torch
import torch.nn as nn
from torch import optim
import os
import re  
import time
import warnings
import numpy as np
import json
from sklearn.metrics import r2_score
import shap
import matplotlib.pyplot as plt
warnings.filterwarnings('ignore')
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from utils.tools import EarlyStopping, adjust_learning_rate
from data_provider.data_loader import Dataset_Custom
from torch.utils.data import DataLoader
import os
import numpy as np
import torch
import pandas as pd  # 导入pandas库
class Exp_Long_Term_Forecast(Exp_Basic):
    def __init__(self, args):
        """Initialise the experiment through the base class and keep ``args``.

        Args:
            args: Experiment configuration object; the other methods of this
                class read options such as ``model``, ``loss`` and
                ``pred_len`` from it.
        """
        super().__init__(args)
        # Make sure the configuration is stored on the instance regardless of
        # what the base initialiser does with it.
        self.args = args

    def _build_model(self):
        model = self.model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        """Return the ``(dataset, loader)`` pair for the split named ``flag``.

        Args:
            flag: Split selector forwarded to ``data_provider``; this class
                passes ``'train'``, ``'val'`` and ``'test'``.
        """
        dataset, loader = data_provider(self.args, flag)
        return dataset, loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        if self.args.loss == 'MSE' or self.args.loss == 'mse':
            criterion = nn.MSELoss()
        elif self.args.loss == 'MAE' or self.args.loss == 'mae':
            criterion = nn.L1Loss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        """Evaluate the model on ``vali_loader`` and return one scalar loss.

        Predictions for every batch are collected on the CPU, concatenated and
        passed to ``metric`` together with the last ``pred_len`` steps of the
        targets.

        Args:
            vali_data: Dataset belonging to ``vali_loader``; not used here.
            vali_loader: Iterable yielding ``(batch_x, batch_y)`` tensors.
            criterion: Selects which metric is returned. ``'MAE'``/``'mae'``
                or an ``nn.L1Loss`` instance selects the MAE; anything else
                selects the MSE.

        Returns:
            The selected metric computed over the whole loader.

        Raises:
            ValueError: If ``vali_loader`` yields no batches.
        """
        pred_chunks = []
        true_chunks = []

        self.model.eval()
        try:
            with torch.no_grad():
                for batch_x, batch_y in vali_loader:
                    batch_x = batch_x.float().to(self.device, non_blocking=True)
                    # Targets stay on the CPU; only the forecast horizon is kept.
                    batch_y = batch_y[:, -self.args.pred_len:, :].float()

                    if self.args.use_amp:
                        with torch.cuda.amp.autocast():
                            outputs = self.model(batch_x)
                    else:
                        outputs = self.model(batch_x)

                    pred_chunks.append(outputs.detach().cpu().numpy())
                    true_chunks.append(batch_y.detach().numpy())
        finally:
            # Restore training mode even if a batch fails, so a caller that
            # handles the error does not keep training in eval mode.
            self.model.train()

        if not pred_chunks:
            raise ValueError("vali_loader produced no batches; cannot compute a validation loss")

        preds = np.concatenate(pred_chunks, axis=0)
        trues = np.concatenate(true_chunks, axis=0)

        # NOTE(review): assumes utils.metrics.metric returns (mse, mae) in this
        # order - confirm against its definition.
        mse, mae = metric(preds, trues)

        use_mae = isinstance(criterion, nn.L1Loss) or criterion == 'MAE' or criterion == 'mae'
        vali_loss = mae if use_mae else mse

        torch.cuda.empty_cache()
        return vali_loss

    def train(self, setting):
        """Train the model, evaluating after every epoch with early stopping.

        For each epoch the model is optimised over the training loader, then
        evaluated on the validation and test loaders through ``vali``. The
        validation loss drives ``EarlyStopping`` and the learning rate is
        adjusted with ``adjust_learning_rate`` at the end of the epoch.

        Args:
            setting: Run name; the checkpoint directory is
                ``os.path.join(self.args.checkpoints, setting)``.

        Returns:
            None.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(path, exist_ok=True)

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        # The gradient scaler is only needed for mixed-precision training.
        scaler = torch.cuda.amp.GradScaler() if self.args.use_amp else None

        for epoch in range(self.args.train_epochs):
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for batch_x, batch_y in train_loader:
                model_optim.zero_grad(set_to_none=True)
                batch_x = batch_x.float().to(self.device, non_blocking=True)
                # Only the forecast horizon of the target window is compared.
                batch_y = batch_y[:, -self.args.pred_len:, :].float().to(self.device, non_blocking=True)

                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x)
                        loss = criterion(outputs, batch_y)
                    train_loss.append(loss.item())
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    outputs = self.model(batch_x)
                    loss = criterion(outputs, batch_y)
                    train_loss.append(loss.item())
                    loss.backward()
                    model_optim.step()
                torch.cuda.empty_cache()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            # vali() receives the configured loss *name* and returns the matching metric.
            vali_loss = self.vali(vali_data, vali_loader, self.args.loss)
            test_loss = self.vali(test_data, test_loader, self.args.loss)
            print("Epoch: {}, Steps: {} | Train Loss: {:.3f}  vali_loss: {:.3f}   test_loss: {:.3f} ".format(epoch + 1, train_steps, train_loss,  vali_loss, test_loss))
            # NOTE(review): EarlyStopping presumably checkpoints the model under
            # `path` when the validation loss improves - confirm in utils.tools.
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)
        torch.cuda.empty_cache()
    def shap_analysis(self):
        """Run a SHAP analysis on rows of the training data and save plots.

        NOTE(review): this class body defines ``shap_analysis`` a second time
        further down. Python binds the name to the later definition, so this
        version is not the one reached through ``self.shap_analysis()``.

        The method samples rows of ``train_data.data_x``, builds a
        ``shap.DeepExplainer``, prints a feature ranking and writes several
        PNG files using relative file names. Any exception is caught and only
        its message is printed.
        """
        print("Starting SHAP analysis...")

        try:
            # Get the training and test datasets
            train_data, train_loader = self._get_data(flag='train')
            test_data, test_loader = self._get_data(flag='test')
            print(f"Train data size: {len(train_data)}")

            # Copy the datasets' data_x arrays into float32 tensors on the target device
            # (test_data_tensor is not used afterwards)
            train_data_tensor = torch.tensor(train_data.data_x, dtype=torch.float32).to(self.device)
            test_data_tensor = torch.tensor(test_data.data_x, dtype=torch.float32).to(self.device)
            print(f"Train data tensor shape: {train_data_tensor.shape}")

            # Ensure the data has the correct shape (samples, features)
            if len(train_data_tensor.shape) != 2:
                raise ValueError(f"Unexpected shape of train_data_tensor: {train_data_tensor.shape}")

            # Get the dimensions
            n_samples, n_features = train_data_tensor.shape
            print(f"Data dimensions: samples={n_samples}, features={n_features}")

            # Convert the data to numpy and map it through the dataset's inverse_transform
            train_data_2d = train_data_tensor.cpu().numpy()
            train_data_2d_inverse = train_data.inverse_transform(train_data_2d)
            print(f"Inverse transformed train_data_2d shape: {train_data_2d_inverse.shape}")

            # Randomly select up to 200 rows (without replacement)
            sample_size = min(200, n_samples)
            sample_indices = np.random.choice(n_samples, sample_size, replace=False)
            sample_data = train_data_2d_inverse[sample_indices]
            print(f"Sample data shape: {sample_data.shape}")


            # background_size equals sample_size here (both are capped at 200),
            # so the background is the whole sample
            background_size = min(200, sample_size)
            background = sample_data[:background_size]
            print(f"Background shape: {background.shape}")

            # Ensure the model is in evaluation mode
            self.model.eval()

            # Create a Deep SHAP explainer
            # NOTE(review): `background` is a NumPy array of inverse-transformed rows;
            # the later shap_analysis definition passes a torch tensor of model-scale
            # windows instead - confirm which input DeepExplainer and the model expect.
            print("Creating DeepExplainer...")
            explainer = shap.DeepExplainer(self.model, background)
            print("DeepExplainer created successfully.")

            # Calculate SHAP values for sample data
            print("Calculating SHAP values...")
            shap_values = explainer.shap_values(sample_data)
            print("SHAP values calculated successfully.")

            print(f"SHAP values shape: {np.array(shap_values).shape}")

            # Handle multi-output case
            if isinstance(shap_values, list):
                shap_values = np.array(shap_values)

            if len(shap_values.shape) == 3:  # (output_dim, samples, features)
                shap_values = np.mean(shap_values, axis=0)  # Take the mean across output dimensions

            print(f"SHAP values shape after processing: {shap_values.shape}")

            # Per-feature mean of |SHAP| (importance) and signed mean (direction)
            shap_values_mean_abs = np.mean(np.abs(shap_values), axis=0)
            shap_values_mean = np.mean(shap_values, axis=0)

            # Get feature names: CSV columns from the third one up to, but not including, the last
            df_raw = pd.read_csv(os.path.join(self.args.root_path, self.args.data_path))
            cols_data = df_raw.columns[2:-1]  # Exclude the last column (TB incidence)

            df_data = df_raw[cols_data]

            feature_names = df_data.columns.tolist()
            print(f"All feature names: {feature_names}")

            # Sort features by importance (descending), excluding the last feature (TB incidence)
            feature_importance_order = np.argsort(shap_values_mean_abs[:-1])[::-1]
            top_10_features = feature_importance_order[:10]
            print(f"Top 10 features: {top_10_features}")

            # Print feature importance ranking
            print("Feature importance ranking and SHAP values:")
            print("Index: Feature Name - Mean Absolute SHAP Value")
            for i, idx in enumerate(feature_importance_order):
                shap_value_abs = shap_values_mean_abs[idx]
                print(f"{i + 1}: {feature_names[idx]} - shap_value_abs: {shap_value_abs:.6f} - shap_value: {shap_values_mean[idx]:.6f}")

            # Print the signed mean SHAP value for each feature
            print("\nSHAP values for each feature:")
            for feature_idx in range(n_features - 1):  # Exclude the last feature
                shap_value = shap_values_mean[feature_idx]
                print(f"Feature: {feature_names[feature_idx]}, SHAP Value: {shap_value:.6f}")

            # Plot SHAP summary plot (violin plot), again without the last column
            print("Plotting SHAP summary plot (violin)...")
            plt.figure(figsize=(12, 8))
            shap.summary_plot(shap_values[:, :-1], sample_data[:, :-1], 
                            feature_names=feature_names,
                            plot_type="violin", show=False)

            plt.xlabel("SHAP value (impact on model output)", family='Times New Roman', fontsize=14)
            # plt.rc changes the global font defaults, so it affects the figures created after this call
            plt.rc('font', family='Times New Roman', size=15)
            plt.tight_layout()
            plt.savefig('shap_summary_plot-1.png')
            plt.close()
            print("SHAP summary plot saved as 'shap_summary_plot-1.png'")

            # Plot feature importance as a horizontal bar chart
            print("Plotting feature importance...")
            plt.figure(figsize=(12, 8))
            top_20_features = feature_importance_order[:20]  # Top 20 features
            top_20_feature_names = [feature_names[idx] for idx in top_20_features]
            plt.barh(top_20_feature_names, shap_values_mean_abs[top_20_features])
            plt.xlabel("mean(|SHAP value|) (average impact on model output magnitude)", fontfamily='Times New Roman', fontsize=14)
            plt.ylabel("Features", fontfamily='Times New Roman', fontsize=14)
            plt.title("Feature Importance", fontfamily='Times New Roman', fontsize=16)
            plt.yticks(fontfamily='Times New Roman', fontsize=12)
            plt.xticks(fontfamily='Times New Roman', fontsize=12)
            # Put the most important feature at the top of the chart
            plt.gca().invert_yaxis()
            plt.tight_layout()
            plt.savefig('feature_importance_plot.png')
            plt.close()
            print("Feature importance plot saved as 'feature_importance_plot.png'")

            # Select top 10 features
            top_10_feature_names = [feature_names[i] for i in top_10_features]
            print(f"Top 10 feature names: {top_10_feature_names}")

            # Additional diagnostic plots (histograms of SHAP values and of model predictions)
            self._check_shap_values_distribution(shap_values)
            self._check_model_predictions(sample_data)

            # Plot SHAP dependence plots for top features
            print("Plotting SHAP dependence plots for top features...")
            for feature_idx in top_10_features:
                plt.figure(figsize=(10, 6))
                sanitized_feature_name = re.sub(r'\W+', '_', feature_names[feature_idx])  # Sanitize feature name for use in a file name
                # NOTE(review): shap_values/sample_data keep their last column here while
                # feature_names excludes it - confirm the lengths agree for dependence_plot.
                shap.dependence_plot(feature_idx, shap_values, sample_data, feature_names=feature_names, show=False)
                plt.tight_layout()
                plt.savefig(f'shap_dependence_plot_{sanitized_feature_name}.png')
                plt.close()
                print(f"SHAP dependence plot for {feature_names[feature_idx]} saved as 'shap_dependence_plot_{sanitized_feature_name}.png'")

        except Exception as e:
            # Best-effort analysis: report the failure and return normally
            print(f"An error occurred during SHAP analysis: {e}")

    def _check_shap_values_distribution(self, shap_values):
        """Save a 50-bin histogram of all SHAP values as a PNG file.

        Args:
            shap_values: Array of SHAP values; it is flattened before
                plotting.

        The figure is written to 'shap_values_distribution.png' (relative
        path) and closed afterwards.
        """
        print("Checking SHAP values distribution...")
        flattened = shap_values.flatten()

        plt.figure(figsize=(10, 6))
        plt.hist(flattened, bins=50, color='blue', alpha=0.7)
        # Decorations are independent of each other; only the figure and the
        # histogram have to exist before them.
        plt.grid(True)
        plt.xlabel("SHAP Value")
        plt.ylabel("Frequency")
        plt.title("Distribution of SHAP Values")
        plt.tight_layout()

        plt.savefig('shap_values_distribution.png')
        plt.close()
        print("SHAP values distribution plot saved as 'shap_values_distribution.png'")

    def _check_model_predictions(self, sample_data, inverse_transform=None):
        """Save a histogram of the model's predictions for ``sample_data``.

        Args:
            sample_data: Array-like model input; converted to a float32
                tensor on ``self.device``.
            inverse_transform: Optional callable applied to the NumPy
                prediction array before plotting (for example a dataset's
                ``inverse_transform``). When omitted, predictions are plotted
                exactly as the model returns them.

        The figure is written to 'model_predictions_distribution.png'
        (relative path) and closed afterwards.
        """
        print("Checking model predictions...")
        with torch.no_grad():
            sample_data_tensor = torch.tensor(sample_data, dtype=torch.float32).to(self.device)
            predictions = self.model(sample_data_tensor).cpu().numpy()

        # A NumPy array has no ``inverse_transform`` method, so rescaling is
        # delegated to a caller-supplied callable instead of being called on
        # the prediction array itself.
        if inverse_transform is not None:
            predictions = inverse_transform(predictions)

        plt.figure(figsize=(10, 6))
        plt.hist(np.asarray(predictions).flatten(), bins=50, color='green', alpha=0.7)
        plt.title("Distribution of Model Predictions")
        plt.xlabel("Prediction")
        plt.ylabel("Frequency")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig('model_predictions_distribution.png')
        plt.close()
        print("Model predictions distribution plot saved as 'model_predictions_distribution.png'")


    def test(self, setting, test=1):
        """Evaluate on the test split and write predictions and metrics to disk.

        Args:
            setting: Run name; the checkpoint is read from
                ``os.path.join(self.args.checkpoints, setting)``.
            test: When truthy, ``checkpoint.pth`` is loaded from that
                directory before evaluating; otherwise the model's current
                weights are used.

        Files written under
        ``./test_dict/<data>/<seq_len>_to_<pred_len>/<model>/<loss>/bz_<batch>/lr_<lr>/``:
            * ``test_results.csv`` - inverse-transformed truth and prediction
              of the target column.
            * ``matching_test_results.csv`` - the rows whose truth value
              occurs in the target column of the raw CSV.
            * ``records.json`` - MSE, MAE and R2 formatted to three decimals.

        Returns:
            None.
        """
        test_data, test_loader = self._get_data(flag='test')
        path = os.path.join(self.args.checkpoints, setting)
        if test:
            print('loading model')
            # map_location lets a checkpoint saved on another device (e.g. a
            # GPU) be loaded when self.device is different (e.g. the CPU).
            state_dict = torch.load(os.path.join(path, 'checkpoint.pth'), map_location=self.device)
            self.model.load_state_dict(state_dict)

        head = f'./test_dict/{self.args.data_path[:-4]}/{self.args.seq_len}_to_{self.args.pred_len}/'
        tail = f'{self.args.model}/{self.args.loss}/bz_{self.args.batch_size}/lr_{self.args.learning_rate}/'
        dict_path = head + tail
        os.makedirs(dict_path, exist_ok=True)

        self.model.eval()

        preds = []
        trues = []
        with torch.no_grad():
            for batch_x, batch_y in test_loader:
                batch_x = batch_x.float().to(self.device, non_blocking=True)
                # Targets stay on the CPU; only the forecast horizon is kept.
                batch_y = batch_y[:, -self.args.pred_len:, :].float()

                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x)
                else:
                    outputs = self.model(batch_x)

                preds.append(outputs.detach().cpu().numpy())
                trues.append(batch_y.detach().numpy())

        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)

        print('test shape:', preds.shape, trues.shape)

        # Number of features in the dataset's input array.
        original_feature_num = test_data.data_x.shape[-1]

        # inverse_transform is applied to arrays with the dataset's full
        # feature width, so narrower arrays are zero-padded with their values
        # placed in the last column (assumed to be the target).
        padded = False
        if preds.shape[-1] != original_feature_num:
            print(f"Warning: Prediction shape {preds.shape} does not match original data shape. Adjusting...")
            preds_adjusted = np.zeros((preds.shape[0], preds.shape[1], original_feature_num))
            # Index the last channel instead of squeeze(): squeeze() also drops
            # a length-1 batch or horizon axis and then the assignment fails.
            preds_adjusted[:, :, -1] = preds[:, :, -1]
            preds = preds_adjusted
            padded = True

        if trues.shape[-1] != original_feature_num:
            print(f"Warning: True values shape {trues.shape} does not match original data shape. Adjusting...")
            trues_adjusted = np.zeros((trues.shape[0], trues.shape[1], original_feature_num))
            trues_adjusted[:, :, -1] = trues[:, :, -1]
            trues = trues_adjusted
            padded = True

        # Reshape the arrays into 2D: one row per (window, horizon step).
        preds = preds.reshape(-1, original_feature_num)
        trues = trues.reshape(-1, original_feature_num)

        # Inverse transform the predictions and truths.
        # NOTE(review): presumably this undoes the dataset's scaling - confirm
        # in the dataset class.
        preds_inverse = test_data.inverse_transform(preds)
        trues_inverse = test_data.inverse_transform(trues)

        # Keep only the target column (assumed to be the last column of data_y).
        target_col = test_data.data_y.shape[-1] - 1
        preds_inverse = preds_inverse[:, target_col]
        trues_inverse = trues_inverse[:, target_col]

        # Save to CSV
        results = pd.DataFrame({
            'Truth': trues_inverse.flatten(),
            'Prediction': preds_inverse.flatten()
        })
        results.to_csv(os.path.join(dict_path, 'test_results.csv'), index=False)

        print('Results saved to', os.path.join(dict_path, 'test_results.csv'))

        # Read the raw data and keep the rows whose truth value occurs in it.
        original_data = pd.read_csv(os.path.join(self.args.root_path, self.args.data_path))
        target_values = original_data[self.args.target].values

        # NOTE(review): this is an exact floating-point comparison; values that
        # went through scaling and back may differ in the last bits and then
        # fail to match - confirm this is intended.
        matching_indices = np.where(np.isin(trues_inverse, target_values))[0]
        matching_results = results.iloc[matching_indices]

        # Save matching results to new CSV
        matching_results.to_csv(os.path.join(dict_path, 'matching_test_results.csv'), index=False)

        # Compute the metrics. When zero padding was applied, only the last
        # column holds real values: including the all-zero columns would shrink
        # MSE/MAE by the feature count and pull R2 towards 1.
        if padded:
            metric_trues, metric_preds = trues[:, -1], preds[:, -1]
        else:
            metric_trues, metric_preds = trues, preds
        mse = mean_squared_error(metric_trues, metric_preds)
        mae = mean_absolute_error(metric_trues, metric_preds)
        r2 = r2_score(metric_trues, metric_preds)
        print('mse: {:.3f}  mae: {:.3f}  r2: {:.3f}'.format(mse, mae, r2))

        # Save the metrics to a JSON file.
        my_dict = {
            'mse': "{:.3f}".format(mse),
            'mae': "{:.3f}".format(mae),
            'r2': "{:.3f}".format(r2)
        }
        with open(os.path.join(dict_path, 'records.json'), 'w') as f:
            json.dump(my_dict, f)

        # Release cached GPU memory.
        torch.cuda.empty_cache()

        return



    def predict(self, setting, predict_data, predict_loader):
        """Load the saved checkpoint and score the model on ``predict_loader``.

        Args:
            setting: Run name; ``checkpoint.pth`` is loaded from
                ``os.path.join(self.args.checkpoints, setting)``.
            predict_data: Dataset belonging to ``predict_loader``; not used
                here.
            predict_loader: Iterable yielding ``(batch_x, batch_y)`` tensors.

        Returns:
            Tuple ``(preds, trues, mse, mae, r2)``. ``preds`` and ``trues``
            are 2-D arrays obtained by flattening every axis except the last
            one; the metrics are computed on these arrays without any inverse
            transform.
        """
        path = os.path.join(self.args.checkpoints, setting)

        print('loading model')
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded when self.device is different (e.g. the CPU).
        state_dict = torch.load(os.path.join(path, 'checkpoint.pth'), map_location=self.device)
        self.model.load_state_dict(state_dict)

        self.model.eval()

        preds = []
        trues = []

        with torch.no_grad():
            for batch_x, batch_y in predict_loader:
                batch_x = batch_x.float().to(self.device, non_blocking=True)
                # Targets stay on the CPU; only the forecast horizon is kept.
                batch_y = batch_y[:, -self.args.pred_len:, :].float()
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x)
                else:
                    outputs = self.model(batch_x)
                preds.append(outputs.detach().cpu().numpy())
                trues.append(batch_y.detach().numpy())

        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        # Collapse (windows, horizon, channels) into (windows * horizon, channels).
        preds = preds.reshape(-1, preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-1])

        mse = mean_squared_error(trues, preds)
        mae = mean_absolute_error(trues, preds)
        r2 = r2_score(trues, preds)
        print('Prediction Results:')
        print('MSE: {:.3f}, MAE: {:.3f}, R^2: {:.3f}'.format(mse, mae, r2))
        return preds, trues, mse, mae, r2
    def shap_analysis(self):
        """Estimate feature importance with SHAP's DeepExplainer and save plots.

        Up to 100 randomly chosen training windows are used both as the
        background set and as the explained samples. SHAP values are computed
        in batches of 10, a ranking by mean absolute SHAP value is printed,
        and two figures are written using relative file names:
        'feature_importance_plot.png' and 'shap_summary_plot.png'.

        The analysis is best-effort: any exception is caught, reported on
        stdout together with its traceback, and not re-raised.

        Returns:
            None.
        """
        print("Starting SHAP analysis...")

        try:
            # Get the training data
            train_data, _ = self._get_data(flag='train')
            print(f"Train data size: {len(train_data)}")

            # Draw at most 100 windows without replacement
            sample_size = min(100, len(train_data))
            indices = np.random.choice(len(train_data), sample_size, replace=False)

            sample_x = np.array([train_data[i][0] for i in indices])
            sample_y = np.array([train_data[i][1] for i in indices])

            print(f"Sample data shape: X - {sample_x.shape}, Y - {sample_y.shape}")

            # Flatten every window to one row for the plotting helpers
            sample_x_2d = sample_x.reshape(sample_x.shape[0], -1)
            print(f"Reshaped sample data shape: X - {sample_x_2d.shape}")

            # background_size equals sample_size here (both are capped at 100),
            # so the background is the whole sample
            background_size = min(100, sample_size)
            background = sample_x_2d[:background_size]
            print(f"Background shape: {background.shape}")

            # Ensure the model is in evaluation mode
            self.model.eval()

            # Create a DeepExplainer on windows restored to their 3-D shape
            print("Creating DeepExplainer...")
            background_tensor = torch.tensor(background, dtype=torch.float32).to(self.device)
            background_tensor = background_tensor.reshape(-1, sample_x.shape[1], sample_x.shape[2])
            explainer = shap.DeepExplainer(self.model, background_tensor)
            print("DeepExplainer created successfully.")

            # Calculate SHAP values for sample data in batches
            print("Calculating SHAP values...")
            batch_size = 10
            shap_values = []
            # The file imports the `tqdm` module, so the progress-bar class is
            # `tqdm.tqdm`; calling the module itself raises TypeError.
            for i in tqdm.tqdm(range(0, sample_x_2d.shape[0], batch_size), desc="SHAP Calculation"):
                batch = torch.tensor(sample_x_2d[i:i+batch_size], dtype=torch.float32).to(self.device)
                batch = batch.reshape(-1, sample_x.shape[1], sample_x.shape[2])
                batch_shap_values = explainer.shap_values(batch)
                shap_values.append(batch_shap_values)

            # Concatenate the results
            if isinstance(shap_values[0], list):
                # If multi-output, concatenate each output separately
                shap_values = [np.concatenate([batch[i] for batch in shap_values], axis=0) for i in range(len(shap_values[0]))]
            else:
                shap_values = np.concatenate(shap_values, axis=0)

            print("SHAP values calculated successfully.")
            print(f"SHAP values shape: {np.array(shap_values).shape}")

            # Handle multi-output case
            if isinstance(shap_values, list):
                shap_values = np.array(shap_values)

            # NOTE(review): this branch assumes a (output_dim, samples, features)
            # layout. With 3-D model inputs the array returned by DeepExplainer may
            # instead be (samples, seq_len, features), in which case this averages
            # over samples - TODO confirm against the installed shap version.
            if len(shap_values.shape) == 3:  # (output_dim, samples, features)
                shap_values = np.mean(shap_values, axis=0)  # Take the mean across output dimensions

            print(f"SHAP values shape after processing: {shap_values.shape}")

            # Ensure shap_values and sample_x_2d have the same number of features
            if shap_values.shape[1] != sample_x_2d.shape[1]:
                print(f"Warning: SHAP values shape ({shap_values.shape}) does not match sample data shape ({sample_x_2d.shape})")
                min_features = min(shap_values.shape[1], sample_x_2d.shape[1])
                shap_values = shap_values[:, :min_features]
                sample_x_2d = sample_x_2d[:, :min_features]
                print(f"Adjusted shapes: SHAP values - {shap_values.shape}, Sample data - {sample_x_2d.shape}")

            # Mean absolute SHAP value: take |.| per sample first, then average.
            # Averaging first would let positive and negative contributions
            # cancel and under-rank features with mixed-sign effects.
            shap_values_mean_abs = np.mean(np.abs(shap_values), axis=0)
            shap_values_mean = np.mean(shap_values, axis=0)

            # Get feature names: every CSV column from the third one onwards
            df_raw = pd.read_csv(os.path.join(self.args.root_path, self.args.data_path))
            cols_data = df_raw.columns[2:]
            df_data = df_raw[cols_data]
            feature_names = df_data.columns.tolist()
            print(f"All feature names: {feature_names}")

            # Ensure feature_names matches the number of features in shap_values
            feature_names = feature_names[:shap_values.shape[1]]
            print(f"Adjusted feature names: {feature_names}")

            # Sort features by importance (descending)
            feature_importance_order = np.argsort(shap_values_mean_abs)[::-1]
            top_10_features = feature_importance_order[:10]
            print(f"Top 10 features by importance: {top_10_features}")

            # Print feature importance ranking
            print("Feature importance ranking and SHAP values:")
            print("Index: Feature Name - Mean Absolute SHAP Value - Mean SHAP Value")
            for i, idx in enumerate(feature_importance_order):
                shap_value_abs = shap_values_mean_abs[idx]
                shap_value = shap_values_mean[idx]
                print(f"{i + 1}: {feature_names[idx]} - Abs SHAP: {shap_value_abs:.6f} - SHAP: {shap_value:.6f}")

            # Plot feature importance
            print("Plotting feature importance...")
            plt.figure(figsize=(12, 8))
            top_20_features = feature_importance_order[:20]  # Top 20 features
            top_20_feature_names = [feature_names[idx] for idx in top_20_features]
            plt.barh(top_20_feature_names, shap_values_mean_abs[top_20_features])
            plt.xlabel("Mean Absolute SHAP Value", fontfamily='Times New Roman', fontsize=14)
            plt.ylabel("Features", fontfamily='Times New Roman', fontsize=14)
            plt.title("Feature Importance by Mean Absolute SHAP Value", fontfamily='Times New Roman', fontsize=16)
            plt.yticks(fontfamily='Times New Roman', fontsize=12)
            plt.xticks(fontfamily='Times New Roman', fontsize=12)
            # Put the most important feature at the top of the chart
            plt.gca().invert_yaxis()
            plt.tight_layout()
            plt.savefig('feature_importance_plot.png')
            plt.close()
            print("Feature importance plot saved as 'feature_importance_plot.png'")

            # Plot SHAP summary plot (violin plot)
            print("Plotting SHAP summary plot (violin)...")
            plt.figure(figsize=(12, 8))
            shap.summary_plot(shap_values, sample_x_2d,
                            feature_names=feature_names,
                            plot_type="violin", show=False)

            plt.xlabel("SHAP value (impact on model output)", family='Times New Roman', fontsize=14)
            # plt.rc changes the global font defaults for figures created later
            plt.rc('font', family='Times New Roman', size=15)
            plt.tight_layout()
            plt.savefig('shap_summary_plot.png')
            plt.close()
            print("SHAP summary plot saved as 'shap_summary_plot.png'")

        except Exception as e:
            # Best-effort analysis: report the failure without re-raising, but
            # keep the traceback so the failing step can be located.
            import traceback

            print(f"An error occurred during SHAP analysis: {e}")
            traceback.print_exc()
